# ------------------------------------------------------------------------------
# Produces subcategory risk stats reported in text
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 05_subcat_risk_stats.sh
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(here)
library(yaml)
library(tidyverse)
library(glue)
library(testit) # assert()
library(broom) # tidy()
library(data.table) # setnames()
library(OneR)

# Path to this step's temp folder (relative to the project root) and the
# shared utility module, exposed as `u`.
temp <- here("code", "07_other_physician_errors", "temp")
u <- modules::use(here("lib", "util.R"))

# Load Data --------------------------------------------------------------------
message("Loading data...")
# NOTE(review): `overnight_lab` is not referenced by name anywhere below; it is
# presumably interpolated by glue() into the cohort path template read from
# filepaths.yml -- confirm before removing.
overnight_lab <- ""
paths <- read_yaml(here::here("lib", "filepaths.yml"))

# Read the test cohort, expose the full-model prediction under the short name
# `yhat_full`, then drop rows flagged by `exclude`.
cohort <- readRDS(glue(paths$analysis$test_cohort))
cohort <- mutate(
  cohort,
  yhat_full = p__ensemble__stent_or_cabg_010_day__tested
)
cohort <- filter(cohort, !exclude)

# Get subcategory predictions --------------------------------------------------
message("Getting subcategory yhats...")
subcats <- c("justcc", "represent", "lvs", "dem")
for (subcat in subcats) {
  # Test-set scores produced by the model for this subcategory.
  predictions <- readRDS(
    file.path(
      paths$modeling$dir,
      "prediction",
      "random",
      subcat,
      "scores_test_set.rds"
    )
  )

  # Shorten the score column to yhat_<subcat>. setnames() renames in place and
  # errors if the expected column is absent.
  setnames(
    predictions,
    old = glue("p__ensemble__stent_or_cabg_010_day__tested__{subcat}"),
    new = glue("yhat_{subcat}")
  )

  # Attach this subcategory's score to the cohort; no join keys are given
  # here, so they are left to the project helper.
  cohort <- cohort %>% u$safe_left_join(predictions)
}

# Bin each subcategory score with OneR::bin(method = "content"), requesting
# four bins per score.
cohort <- mutate(
  cohort,
  # No labels are passed to bin() here; the result is relabelled 1:3, which
  # only succeeds if bin() returns exactly three levels. Original author's
  # note: about half of the mass sits in the middle bin, so levels 1 and 3 can
  # be read as the bottom and top quartiles.
  yhat_tile_justcc = factor(
    bin(yhat_justcc, nbins = 4, method = "content"),
    labels = 1:3
  ),
  # Demographic and LVS scores keep all four bins, labelled 1 (lowest) to 4.
  yhat_tile_dem = factor(
    bin(yhat_dem, nbins = 4, labels = 1:4, method = "content")
  ),
  yhat_tile_lvs = factor(
    bin(yhat_lvs, nbins = 4, labels = 1:4, method = "content")
  )
)

# Test vs. Risk: regress testing on the full-model prediction and keep the
# slope, which the sections below use to convert effects into y-hat units.
baseline <- lm(formula = test_010_day ~ yhat_full, data = cohort)
yhat_coef <- coefficients(baseline)["yhat_full"]

# Symptom Risk -----------------------------------------------------------------
message("Symptom risk and test rate")
# Flag the highest symptom-risk bin (level 3 of yhat_tile_justcc).
sym_df <- mutate(cohort, hi_justcc_yhat = yhat_tile_justcc == 3)

# Regress testing on the full-model prediction plus the high-risk flag.
fit <- lm(test_010_day ~ yhat_full + hi_justcc_yhat, data = sym_df)
print(tidy(fit))

# Coefficient on the flag: difference in testing for flagged patients,
# conditional on yhat_full.
# NOTE(review): this is on the scale of test_010_day; it is percentage points
# only if the outcome is coded 0/100 -- confirm.
pct_increase <- coefficients(fit)["hi_justcc_yhatTRUE"]
message(
  "Increase (p.p.) in testing for those in top quartile of symptom-based risk: ",
  pct_increase
)
# Divide by the baseline slope to express the effect in units of yhat_full.
message("Raw y-hat equivalent: ", pct_increase / yhat_coef)



# Demographic Risk -------------------------------------------------------------
message("Demographic risk and test rate")
# Flag the highest of the four demographic-risk bins.
dem_df <- mutate(cohort, hi_dem_yhat = yhat_tile_dem == 4)

# Regress testing on the full-model prediction plus the high-risk flag.
fit <- lm(test_010_day ~ yhat_full + hi_dem_yhat, data = dem_df)
print(tidy(fit))

# Coefficient on the flag: difference in testing for flagged patients,
# conditional on yhat_full.
# NOTE(review): this is on the scale of test_010_day; it is percentage points
# only if the outcome is coded 0/100 -- confirm.
pct_increase <- coefficients(fit)["hi_dem_yhatTRUE"]
message(
  "Increase (p.p.) in testing for those in top quartile of dem-based risk: ",
  pct_increase
)
# Divide by the baseline slope to express the effect in units of yhat_full.
message("Raw y-hat equivalent: ", pct_increase / yhat_coef)

# LVS Risk ---------------------------------------------------------------------
message("LVS risk and test rate")
# Flag the highest of the four LVS-risk bins.
lvs_df <- mutate(cohort, hi_lvs_yhat = yhat_tile_lvs == 4)

# Regress testing on the full-model prediction plus the high-risk flag.
fit <- lm(test_010_day ~ yhat_full + hi_lvs_yhat, data = lvs_df)
print(tidy(fit))

# Coefficient on the flag: difference in testing for flagged patients,
# conditional on yhat_full.
# NOTE(review): this is on the scale of test_010_day; it is percentage points
# only if the outcome is coded 0/100 -- confirm.
pct_increase <- coefficients(fit)["hi_lvs_yhatTRUE"]
message(
  "Increase (p.p.) in testing for those in top quartile of lvs-based risk: ",
  pct_increase
)
# Divide by the baseline slope to express the effect in units of yhat_full.
message("Raw y-hat equivalent: ", pct_increase / yhat_coef)

# Representative Risk ----------------------------------------------------------
message("Representative risk and test rate")
# Flag patients whose representative-model score is strictly above the cohort
# median. Unlike the other sections, this split does not use binned scores.
rep_df <- cohort %>%
  mutate(hi_rep_yhat = yhat_represent > median(yhat_represent))

# The summary line further down hard-codes "7%" for the size of the flagged
# group, but nothing in this script derives that number. Log the realized share
# so the label can be checked against the data on every run.
message(
  "Share of cohort above median rep-based risk (%): ",
  round(100 * mean(rep_df$hi_rep_yhat), 2)
)

# Regress testing on the full-model prediction plus the high-risk flag.
fit <- lm(test_010_day ~ yhat_full + hi_rep_yhat, data = rep_df)
print(tidy(fit))

# Coefficient on the flag: difference in testing for flagged patients,
# conditional on yhat_full.
# NOTE(review): this is on the scale of test_010_day; it is percentage points
# only if the outcome is coded 0/100 -- confirm.
pct_increase <- coef(fit)["hi_rep_yhatTRUE"]
message(
  "Increase (p.p.) in testing for those in top (7%) of rep-based risk: ",
  pct_increase
)
# Divide by the baseline slope to express the effect in units of yhat_full.
message(
  "Raw y-hat equivalent: ", pct_increase / yhat_coef
)
